
[Refactor][vLLM] Replace worker extension with extract_hidden_states KV connector#57

Merged
yubofredwang merged 2 commits into main from ywang/extract-hidden-states-connector
Apr 2, 2026

Conversation


@yubofredwang yubofredwang commented Mar 30, 2026

Switch hidden states capture from monkey-patching model.forward via VllmWorkerExtension to vLLM's public extract_hidden_states speculative method paired with a custom MooncakeHiddenStatesConnector KV connector.

Key changes:

  • Add MooncakeHiddenStatesConnector that writes hidden states to Mooncake via RDMA directly from vLLM worker processes
  • Rewrite VllmEngine to use speculative_config + kv_transfer_config
  • Delete VllmWorkerExtension (774 lines) — no longer needed
  • Add verifier norm support to TargetLMHead and Eagle3Trainer for pre-norm hidden states (vLLM extract_hidden_states captures pre-norm outputs)
  • Add last_hidden_states_prenorm config with auto-detection per engine type
  • Add proper engine shutdown in training loop cleanup
  • Add output count mismatch guard in inference manager
  • Bump vLLM minimum to >=0.18.0
  • Update tests for new connector/engine

Fixes #53
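The verifier-norm change above exists because vLLM's extract_hidden_states returns pre-norm outputs, so the trainer has to apply the target model's final norm before projecting through lm_head. A minimal sketch of that step, assuming an RMSNorm-style final layer; the function and variable names here are illustrative, not TorchSpec's actual API:

```python
# Hypothetical sketch: extract_hidden_states captures *pre-norm* hidden
# states, so apply the verifier's final norm (RMSNorm here) before the
# lm_head projection. rms_norm/verifier_weight are made-up names.
import math

def rms_norm(x, weight, eps=1e-6):
    """Minimal RMSNorm over a single vector (list of floats)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

# A pre-norm hidden state as captured from the target model.
prenorm_hidden = [0.5, -1.0, 2.0, 0.25]
verifier_weight = [1.0, 1.0, 1.0, 1.0]  # e.g. the model.norm.weight tensor

normed = rms_norm(prenorm_hidden, verifier_weight)

# After normalization the vector has (approximately) unit RMS,
# which is what lm_head expects as input.
print(round(math.sqrt(sum(v * v for v in normed) / len(normed)), 3))
# → 1.0
```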

@yubofredwang yubofredwang force-pushed the ywang/extract-hidden-states-connector branch from 8987560 to 7c4fb89 on March 30, 2026 09:46
@yubofredwang yubofredwang marked this pull request as ready for review March 30, 2026 09:46
Copilot AI review requested due to automatic review settings March 30, 2026 09:46
…MooncakeHiddenStatesConnector


Copilot AI left a comment


Pull request overview

Refactors TorchSpec’s vLLM integration to capture hidden states via vLLM’s public extract_hidden_states speculative method and a custom KV connector, replacing the prior worker-extension monkey-patching approach.

Changes:

  • Introduces MooncakeHiddenStatesConnector to write hidden states directly to Mooncake from vLLM worker processes via KV transfer.
  • Updates VllmEngine to use speculative_config + kv_transfer_config, and removes the legacy VllmWorkerExtension.
  • Adjusts training/config/test plumbing for pre-norm last_hidden_states (verifier norm), adds safer shutdown/guards, and bumps vLLM minimum to >=0.18.0.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 6 comments.

Summary per file:

  • torchspec/inference/engine/vllm_engine.py: Switches vLLM engine wiring to extract_hidden_states + KV transfer metadata flow.
  • torchspec/inference/engine/mooncake_hidden_states_connector.py: Adds a KV connector that extracts hidden states from KV cache and stores to Mooncake.
  • torchspec/inference/engine/vllm_worker_extension.py: Deletes the previous monkey-patching worker extension implementation.
  • torchspec/training/eagle3_trainer.py: Applies optional verifier norm to pre-norm last_hidden_states before target projection.
  • torchspec/models/target/target_utils.py: Adds optional loading/initialization of final norm alongside lm_head for pre-norm handling.
  • torchspec/controller/loop.py: Ensures inference engines are shut down during training cleanup.
  • torchspec/controller/inference_manager.py: Adds an output-count mismatch guard to avoid incorrect zipping/dispatch.
  • torchspec/controller/eval.py: Adjusts initial eval submission sizing with inference_batch_size.
  • torchspec/config/train_config.py: Auto-defaults last_hidden_states_prenorm based on engine type.
  • torchspec/config/inference_config.py: Documents new vLLM behavior and adds a last_hidden_states_prenorm resolution helper.
  • tests/test_vllm_engine.py: Updates unit tests for the new connector/metadata flow and adds chunked-prefill coverage.
  • pyproject.toml: Bumps vLLM extra requirement to >=0.18.0.
  • configs/vllm_qwen3_8b.yaml: Updates example config for the new connector-based capture path.
  • examples/data/sample_conversations.jsonl: Adds an additional sample conversation entry used for long/chunked scenarios.



assert hs_shape[1] == num_training_layers * hidden_size
assert lhs_shape[1] == hidden_size
assert hs_shape[1] + lhs_shape[1] != (num_training_layers + 1) * hidden_size or True

Copilot AI Mar 30, 2026


This assertion is a no-op: ... != ... or True always evaluates to True, so it can’t catch regressions in the hidden_states/last_hidden_states split. Replace it with a meaningful check (e.g., assert equality to the expected combined width, or assert the combined width equals num_hidden_states * hidden_size).

Suggested change
assert hs_shape[1] + lhs_shape[1] != (num_training_layers + 1) * hidden_size or True
assert hs_shape[1] + lhs_shape[1] == (num_training_layers + 1) * hidden_size
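A quick way to see why the flagged assertion can never fire: `or True` makes the whole expression truthy regardless of the comparison. A standalone demonstration with made-up widths (the values are illustrative, not the repo's real dimensions):

```python
# Made-up shapes to illustrate the no-op: 3 training layers of width 1024
# plus one last_hidden_states slice of width 1024.
hs_width, lhs_width, hidden_size, num_training_layers = 3072, 1024, 1024, 3
expected = (num_training_layers + 1) * hidden_size  # 4096

# Original form: `<anything> or True` is always True, so this assert
# passes even when the widths are deliberately wrong.
assert (hs_width + lhs_width != expected + 999) or True

# Meaningful form: actually validates the combined width.
assert hs_width + lhs_width == expected
print("combined width check passed")
```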


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7c4fb89136


…errides

When pre-norm last hidden states are enabled, TargetLMHead always used
the default norm_key (model.norm.weight) because only lm_head_key was
forwarded from config. For models with custom key prefixes, norm loading
silently failed, leaving self.norm=None and corrupting target logits.
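The fix this commit describes amounts to resolving the norm key from config rather than always falling back to the default. A minimal sketch of the failure mode, with hypothetical checkpoint key names and a resolve_norm_key helper that is not the repo's actual code:

```python
# Hypothetical sketch: when only lm_head_key is forwarded from config,
# the norm loader falls back to the default key and silently misses
# weights stored under a custom prefix, leaving the norm unloaded.
DEFAULT_NORM_KEY = "model.norm.weight"

def resolve_norm_key(config: dict) -> str:
    """Post-fix behavior: honor a configured norm_key override."""
    return config.get("norm_key", DEFAULT_NORM_KEY)

# A checkpoint whose tensors live under a custom prefix.
checkpoint_keys = {"language_model.model.norm.weight": "<tensor>"}

# Pre-fix behavior (no override forwarded): lookup misses, so the
# norm would stay None and target logits would be computed un-normed.
assert resolve_norm_key({}) not in checkpoint_keys

# Post-fix behavior: forwarding norm_key alongside lm_head_key
# resolves to the actual checkpoint entry.
cfg = {"norm_key": "language_model.model.norm.weight"}
assert resolve_norm_key(cfg) in checkpoint_keys
print("norm_key resolved")
```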
@yubofredwang yubofredwang merged commit ebf869e into main Apr 2, 2026
1 check passed
@yubofredwang yubofredwang deleted the ywang/extract-hidden-states-connector branch April 2, 2026 00:09
